





<!DOCTYPE html>
<html class="writer-html5" lang="zh-CN" >
<head>
  <meta charset="utf-8">
  
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  
  <title>如何在CPU上优化GEMM（通用矩阵乘） &mdash; tvm 0.8.dev1982 文档</title>
  

  
  <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
  <link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
  <link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
  <link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
  <link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
  <link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
  <link rel="stylesheet" href="../../_static/css/tlcpack_theme.css" type="text/css" />

  
  
    <link rel="shortcut icon" href="../../_static/tvm-logo-square.png"/>
  

  
  
  
  
    
        <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
        <script src="../../_static/jquery.js"></script>
        <script src="../../_static/underscore.js"></script>
        <script src="../../_static/doctools.js"></script>
        <script src="../../_static/translations.js"></script>
    
    <script type="text/javascript" src="../../_static/js/theme.js"></script>

    
    <script type="text/javascript" src="../../_static/js/tlcpack_theme.js"></script>
    <link rel="index" title="索引" href="../../genindex.html" />
    <link rel="search" title="搜索" href="../../search.html" />
    <link rel="next" title="如何在GPU上优化卷积" href="opt_conv_cuda.html" />
    <link rel="prev" title="优化张量算子" href="index.html" /> 
</head>

<body class="wy-body-for-nav">

   
  <div class="wy-grid-for-nav">
    
    
<header class="header">
    <div class="innercontainer">
      <div class="headerInner d-flex justify-content-between align-items-center">
          <div class="headerLogo">
               <a href="https://tvm.apache.org/"><img src="https://tvm.apache.org/assets/images/logo.svg" alt="Apache TVM"></a>
          </div>

          <div id="headMenu" class="headerNav">
            <button type="button" id="closeHeadMenu" class="navCloseBtn"><img src="../../_static/img/close-icon.svg" alt="Close"></button>
             <ul class="nav">
                <li class="nav-item">
                   <a class="nav-link" href=https://tvm.apache.org/community>Community</a>
                </li>
                <li class="nav-item">
                   <a class="nav-link" href=https://tvm.apache.org/download>Download</a>
                </li>
                <li class="nav-item">
                   <a class="nav-link" href=https://tvm.apache.org/vta>VTA</a>
                </li>
                <li class="nav-item">
                   <a class="nav-link" href=https://tvm.apache.org/blog>Blog</a>
                </li>
                <li class="nav-item">
                   <a class="nav-link" href=https://tvm.apache.org/docs>Docs</a>
                </li>
                <li class="nav-item">
                   <a class="nav-link" href=https://tvmconf.org>Conference</a>
                </li>
                <li class="nav-item">
                   <a class="nav-link" href=https://github.com/apache/tvm/>Github</a>
                </li>
                <li class="nav-item">
                   <a class="nav-link" href=https://tvmchinese.github.io/declaration_zh_CN.html>About-Translators</a>
                </li>
             </ul>
               <div class="responsivetlcdropdown">
                 <button type="button" class="btn-link">
                   ASF
                 </button>
                 <ul>
                     <li>
                       <a href=https://apache.org/>Apache Homepage</a>
                     </li>
                     <li>
                       <a href=https://www.apache.org/licenses/>License</a>
                     </li>
                     <li>
                       <a href=https://www.apache.org/foundation/sponsorship.html>Sponsorship</a>
                     </li>
                     <li>
                       <a href=https://www.apache.org/security/>Security</a>
                     </li>
                     <li>
                       <a href=https://www.apache.org/foundation/thanks.html>Thanks</a>
                     </li>
                     <li>
                       <a href=https://www.apache.org/events/current-event>Events</a>
                     </li>
                     <li>
                       <a href=https://www.zhihu.com/column/c_1429578595417563136>Zhihu</a>
                     </li>
                 </ul>
               </div>
          </div>
            <div class="responsiveMenuIcon">
              <button type="button" id="menuBtn" class="btn-menu"><img src="../../_static/img/menu-icon.svg" alt="Menu Icon"></button>
            </div>

            <div class="tlcDropdown">
              <div class="dropdown">
                <button type="button" class="btn-link dropdown-toggle" data-toggle="dropdown" aria-haspopup="true" aria-expanded="false">
                  ASF
                </button>
                <div class="dropdown-menu dropdown-menu-right">
                  <ul>
                     <li>
                       <a href=https://apache.org/>Apache Homepage</a>
                     </li>
                     <li>
                       <a href=https://www.apache.org/licenses/>License</a>
                     </li>
                     <li>
                       <a href=https://www.apache.org/foundation/sponsorship.html>Sponsorship</a>
                     </li>
                     <li>
                       <a href=https://www.apache.org/security/>Security</a>
                     </li>
                     <li>
                       <a href=https://www.apache.org/foundation/thanks.html>Thanks</a>
                     </li>
                     <li>
                       <a href=https://www.apache.org/events/current-event>Events</a>
                     </li>
                     <li>
                       <a href=https://www.zhihu.com/column/c_1429578595417563136>Zhihu</a>
                     </li>
                  </ul>
                </div>
              </div>
          </div>
       </div>
    </div>
 </header>
 
    <nav data-toggle="wy-nav-shift" class="wy-nav-side fixed">
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search" >
          

          
            <a href="../../index.html">
          

          
            
            <img src="../../_static/tvm-logo-small.png" class="logo" alt="Logo"/>
          
          </a>

          
            
            
                <div class="version">
                  0.8.dev1982
                </div>
            
          

          
<div role="search">
  <form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
    <input type="hidden" name="check_keywords" value="yes" />
    <input type="hidden" name="area" value="default" />
  </form>
</div>

          
        </div>

        
        <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
          
            
            
              
            
            
              <p class="caption" role="heading"><span class="caption-text">如何开始</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../install/index.html">安装 TVM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../contribute/index.html">贡献者指南</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">用户引导</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../../tutorial/index.html">User Tutorial</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="../index.html">How To Guides</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="../compile_models/index.html">编译深度学习模型</a></li>
<li class="toctree-l2"><a class="reference internal" href="../deploy/index.html">TVM 部署模型和集成</a></li>
<li class="toctree-l2"><a class="reference internal" href="../work_with_relay/index.html">Work With Relay</a></li>
<li class="toctree-l2"><a class="reference internal" href="../work_with_schedules/index.html">Work With Tensor Expression and Schedules</a></li>
<li class="toctree-l2 current"><a class="reference internal" href="index.html">优化张量算子</a><ul class="current">
<li class="toctree-l3 current"><a class="current reference internal" href="#">如何在CPU上优化GEMM（通用矩阵乘）</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#preparation-and-baseline">准备和基线（Baseline）</a></li>
<li class="toctree-l4"><a class="reference internal" href="#blocking">分块</a></li>
<li class="toctree-l4"><a class="reference internal" href="#vectorization">矢量化</a></li>
<li class="toctree-l4"><a class="reference internal" href="#loop-permutation">循环重排</a></li>
<li class="toctree-l4"><a class="reference internal" href="#array-packing">数组打包</a></li>
<li class="toctree-l4"><a class="reference internal" href="#write-cache-for-blocks">块的写缓存</a></li>
<li class="toctree-l4"><a class="reference internal" href="#parallel">并行</a></li>
<li class="toctree-l4"><a class="reference internal" href="#summary">总结</a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="opt_conv_cuda.html">如何在GPU上优化卷积</a></li>
<li class="toctree-l3"><a class="reference internal" href="opt_conv_tensorcore.html">How to optimize convolution using TensorCores</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../tune_with_autotvm/index.html">Auto-Tune with Templates and AutoTVM</a></li>
<li class="toctree-l2"><a class="reference internal" href="../tune_with_autoscheduler/index.html">Use AutoScheduler for Template-Free Scheduling</a></li>
<li class="toctree-l2"><a class="reference internal" href="../work_with_microtvm/index.html">Work With microTVM</a></li>
<li class="toctree-l2"><a class="reference internal" href="../extend_tvm/index.html">Extend TVM</a></li>
<li class="toctree-l2"><a class="reference internal" href="../profile/index.html">Profile Models</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../errors.html">Handle TVM Errors</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../faq.html">常见提问</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">开发者引导</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../dev/tutorial/index.html">Developer Tutorial</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../dev/how_to/how_to.html">开发者指南</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">架构指南</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../arch/index.html">Design and Architecture</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">主题引导</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../topic/microtvm/index.html">microTVM：裸机使用TVM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../topic/vta/index.html">VTA: Versatile Tensor Accelerator</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">参考指南</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../reference/langref/index.html">语言参考</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/api/python/index.html">Python API</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/api/links.html">Other APIs</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/publications.html">Publications</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../genindex.html">索引</a></li>
</ul>

            
          
        </div>
        
      </div>
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
      
      <nav class="wy-nav-top" aria-label="top navigation" data-toggle="wy-nav-top">
        
            <div class="togglemenu">

            </div>
            <div class="nav-content">
              <!-- tvm -->
              Table of contents
            </div>
        
      </nav>


      <div class="wy-nav-content">
        
        <div class="rst-content">
        

          




















<div role="navigation" aria-label="breadcrumbs navigation">

  <ul class="wy-breadcrumbs">
    
      <li><a href="../../index.html">Docs</a> <span class="br-arrow">></span></li>
        
          <li><a href="../index.html">How To Guides</a> <span class="br-arrow">></span></li>
        
          <li><a href="index.html">优化张量算子</a> <span class="br-arrow">></span></li>
        
      <li>如何在CPU上优化GEMM（通用矩阵乘）</li>
    
    
      <li class="wy-breadcrumbs-aside">
        
            
            <a href="../../_sources/how_to/optimize_operators/opt_gemm.rst.txt" rel="nofollow"> <img src="../../_static/img/source.svg" alt="查看页面源码"/></a>
          
        
      </li>
    
  </ul>

  
  <hr/>
</div>
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
            
  <div class="sphx-glr-download-link-note admonition note">
<p class="admonition-title">注解</p>
<p>点击 <a class="reference internal" href="#sphx-glr-download-how-to-optimize-operators-opt-gemm-py"><span class="std std-ref">这里</span></a> 下载完整的样例代码</p>
</div>
<div class="sphx-glr-example-title section" id="how-to-optimize-gemm-on-cpu">
<span id="opt-gemm"></span><span id="sphx-glr-how-to-optimize-operators-opt-gemm-py"></span><h1>如何在CPU上优化GEMM（通用矩阵乘）<a class="headerlink" href="#how-to-optimize-gemm-on-cpu" title="永久链接至标题">¶</a></h1>
<p><strong>作者</strong>: <a class="reference external" href="https://github.com/were">Jian Weng</a>, <a class="reference external" href="https://github.com/yuruofeifei">Ruofei Yu</a></p>
<p>TVM 提供了抽象接口，允许用户分别描述算法本身和算法的实现组织方式（即所谓的 schedule）。通常，以高性能的 schedule 编写算法会破坏算法的可读性和模块化；此外，逐一尝试各种看起来有前景的 schedule 也非常耗时。借助 TVM，我们可以高效地尝试这些 schedule 来提升性能。</p>
<p>在本教程中，我们将演示如何使用 TVM 优化方阵乘法，并且只需额外添加 18 行代码，就能获得比基线（baseline）快 200 倍的性能。</p>
<dl class="simple">
<dt>在 CPU 上执行的密集计算应用程序有两个重要的优化：</dt><dd><ol class="arabic simple">
<li><p>提高内存访问的缓存命中率。复杂的数值计算和热点内存访问都可以通过高缓存命中率来加速。 这需要我们将原始内存访问模式转换为适合缓存策略的模式。</p></li>
<li><p>SIMD（单指令多数据），也称为向量处理单元。它每次处理一小批数据，而不是单个数据。这需要我们以统一的方式变换循环体中的数据访问模式，以便 LLVM 后端能将其 lower 为 SIMD 指令。</p></li>
</ol>
</dd>
</dl>
<p>实际上，本教程中使用的所有方法都是这个 <a class="reference external" href="https://github.com/flame/how-to-optimize-gemm">repo</a> 中提到的技巧的一个子集。其中一些已经由 TVM 的抽象自动应用，但另一些由于 TVM 的限制，无法简单地应用。</p>
<p>下面提到的所有实验结果，都是在配备 Intel i7-4770HQ CPU 的 2015 款 15 英寸 MacBook 上得到的。对于所有 x86 CPU，缓存行大小应为 64 字节。</p>
<div class="section" id="preparation-and-baseline">
<h2>准备和基线（Baseline）<a class="headerlink" href="#preparation-and-baseline" title="永久链接至标题">¶</a></h2>
<p>在本教程中，我们将演示如何使用 TVM 优化矩阵乘法。 在实际演示之前，我们首先定义这些变量。 然后我们编写一个基线实现，这是在 TVM 中编写矩阵乘法的最简单方法。</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">tvm</span>
<span class="kn">import</span> <span class="nn">tvm.testing</span>
<span class="kn">from</span> <span class="nn">tvm</span> <span class="k">import</span> <span class="n">te</span>
<span class="kn">import</span> <span class="nn">numpy</span>
<span class="kn">import</span> <span class="nn">timeit</span>

<span class="c1"># The size of the matrix</span>
<span class="c1"># (M, K) x (K, N)</span>
<span class="c1"># You are free to try out different shapes, sometimes TVM optimization outperforms numpy with MKL.</span>
<span class="n">M</span> <span class="o">=</span> <span class="mi">1024</span>
<span class="n">K</span> <span class="o">=</span> <span class="mi">1024</span>
<span class="n">N</span> <span class="o">=</span> <span class="mi">1024</span>

<span class="c1"># The default tensor type in tvm</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="s2">&quot;float32&quot;</span>

<span class="c1"># using Intel AVX2(Advanced Vector Extensions) ISA for SIMD</span>
<span class="c1"># To get the best performance, please change the following line</span>
<span class="c1"># to llvm -mcpu=core-avx2, or specific type of CPU you use</span>
<span class="n">target</span> <span class="o">=</span> <span class="s2">&quot;llvm&quot;</span>
<span class="n">dev</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="n">target</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>

<span class="c1"># Random generated tensor for testing</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">numpy</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">K</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">),</span> <span class="n">dev</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">numpy</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">K</span><span class="p">,</span> <span class="n">N</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">),</span> <span class="n">dev</span><span class="p">)</span>

<span class="n">np_repeat</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">np_runing_time</span> <span class="o">=</span> <span class="n">timeit</span><span class="o">.</span><span class="n">timeit</span><span class="p">(</span>
    <span class="n">setup</span><span class="o">=</span><span class="s2">&quot;import numpy</span><span class="se">\n</span><span class="s2">&quot;</span>
    <span class="s2">&quot;M = &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">M</span><span class="p">)</span> <span class="o">+</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span>
    <span class="s2">&quot;K = &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">K</span><span class="p">)</span> <span class="o">+</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span>
    <span class="s2">&quot;N = &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">N</span><span class="p">)</span> <span class="o">+</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span>
    <span class="s1">&#39;dtype = &quot;float32&quot;</span><span class="se">\n</span><span class="s1">&#39;</span>
    <span class="s2">&quot;a = numpy.random.rand(M, K).astype(dtype)</span><span class="se">\n</span><span class="s2">&quot;</span>
    <span class="s2">&quot;b = numpy.random.rand(K, N).astype(dtype)</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">,</span>
    <span class="n">stmt</span><span class="o">=</span><span class="s2">&quot;answer = numpy.dot(a, b)&quot;</span><span class="p">,</span>
    <span class="n">number</span><span class="o">=</span><span class="n">np_repeat</span><span class="p">,</span>
<span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Numpy running time: </span><span class="si">%f</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="n">np_runing_time</span> <span class="o">/</span> <span class="n">np_repeat</span><span class="p">))</span>

<span class="n">answer</span> <span class="o">=</span> <span class="n">numpy</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">b</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>

<span class="c1"># Algorithm</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">te</span><span class="o">.</span><span class="n">reduce_axis</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="n">K</span><span class="p">),</span> <span class="s2">&quot;k&quot;</span><span class="p">)</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">te</span><span class="o">.</span><span class="n">placeholder</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">K</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;A&quot;</span><span class="p">)</span>
<span class="n">B</span> <span class="o">=</span> <span class="n">te</span><span class="o">.</span><span class="n">placeholder</span><span class="p">((</span><span class="n">K</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;B&quot;</span><span class="p">)</span>
<span class="n">C</span> <span class="o">=</span> <span class="n">te</span><span class="o">.</span><span class="n">compute</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="k">lambda</span> <span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">:</span> <span class="n">te</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">A</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">k</span><span class="p">]</span> <span class="o">*</span> <span class="n">B</span><span class="p">[</span><span class="n">k</span><span class="p">,</span> <span class="n">n</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="n">k</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;C&quot;</span><span class="p">)</span>

<span class="c1"># Default schedule</span>
<span class="n">s</span> <span class="o">=</span> <span class="n">te</span><span class="o">.</span><span class="n">create_schedule</span><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="p">)</span>
<span class="n">func</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="p">[</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">],</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;mmult&quot;</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">func</span>

<span class="n">c</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">numpy</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span> <span class="n">dev</span><span class="p">)</span>
<span class="n">func</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
<span class="n">tvm</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">answer</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">)</span>

<span class="n">evaluator</span> <span class="o">=</span> <span class="n">func</span><span class="o">.</span><span class="n">time_evaluator</span><span class="p">(</span><span class="n">func</span><span class="o">.</span><span class="n">entry_name</span><span class="p">,</span> <span class="n">dev</span><span class="p">,</span> <span class="n">number</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Baseline: </span><span class="si">%f</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">evaluator</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">输出:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>Numpy running time: 0.017804
Baseline: 2.980112
</pre></div>
</div>
<p>在 TVM 中，我们始终可以检查较低级别的 IR 以调试或优化我们的schedule。 这是使用我们的基线schedule生成的 IR。</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nb">print</span><span class="p">(</span><span class="n">tvm</span><span class="o">.</span><span class="n">lower</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="p">[</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">],</span> <span class="n">simple_mode</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">输出:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>primfn(A_1: handle, B_1: handle, C_1: handle) -&gt; ()
  attr = {&quot;from_legacy_te_schedule&quot;: True, &quot;global_symbol&quot;: &quot;main&quot;, &quot;tir.noalias&quot;: True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [1024, 1024], []),
             A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], []),
             B: Buffer(B_2: Pointer(float32), float32, [1024, 1024], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (m: int32, 0, 1024) {
    for (n: int32, 0, 1024) {
      C_2[((m*1024) + n)] = 0f32
      for (k: int32, 0, 1024) {
        C_2[((m*1024) + n)] = ((float32*)C_2[((m*1024) + n)] + ((float32*)A_2[((m*1024) + k)]*(float32*)B_2[((k*1024) + n)]))
      }
    }
  }
}
</pre></div>
</div>
</div>
<div class="section" id="blocking">
<h2>分块<a class="headerlink" href="#blocking" title="永久链接至标题">¶</a></h2>
<p>提高缓存命中率的一个重要技巧是分块——将数据划分成块，逐块进行计算。块内部的内存访问集中在一个小邻域内，具有很高的内存局部性。在本教程中，我们选择 32 作为分块因子。因此每个块将向缓存中填充 32 * 32 * sizeof(float)，即 4KB 的数据，而 L1 数据缓存的总大小为 32KB。</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">bn</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">kfactor</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">s</span> <span class="o">=</span> <span class="n">te</span><span class="o">.</span><span class="n">create_schedule</span><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="p">)</span>

<span class="c1"># Blocking by loop tiling</span>
<span class="n">mo</span><span class="p">,</span> <span class="n">no</span><span class="p">,</span> <span class="n">mi</span><span class="p">,</span> <span class="n">ni</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">bn</span><span class="p">,</span> <span class="n">bn</span><span class="p">)</span>
<span class="p">(</span><span class="n">kaxis</span><span class="p">,)</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">reduce_axis</span>
<span class="n">ko</span><span class="p">,</span> <span class="n">ki</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">kaxis</span><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><span class="n">kfactor</span><span class="p">)</span>

<span class="c1"># Hoist reduction domain outside the blocking loop</span>
<span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><span class="n">mo</span><span class="p">,</span> <span class="n">no</span><span class="p">,</span> <span class="n">ko</span><span class="p">,</span> <span class="n">ki</span><span class="p">,</span> <span class="n">mi</span><span class="p">,</span> <span class="n">ni</span><span class="p">)</span>

<span class="n">func</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="p">[</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">],</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;mmult&quot;</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">func</span>

<span class="n">c</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">numpy</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span> <span class="n">dev</span><span class="p">)</span>
<span class="n">func</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
<span class="n">tvm</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">answer</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">)</span>

<span class="c1"># By simply tiling the loop 32x32, and hoisting ko, ki outside the blocking loops,</span>
<span class="c1"># we can see big speedup compared with the baseline.</span>
<span class="n">evaluator</span> <span class="o">=</span> <span class="n">func</span><span class="o">.</span><span class="n">time_evaluator</span><span class="p">(</span><span class="n">func</span><span class="o">.</span><span class="n">entry_name</span><span class="p">,</span> <span class="n">dev</span><span class="p">,</span> <span class="n">number</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Opt1: </span><span class="si">%f</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">evaluator</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">输出:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>Opt1: 0.215588
</pre></div>
</div>
<p>这是分块之后产生的 IR。</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nb">print</span><span class="p">(</span><span class="n">tvm</span><span class="o">.</span><span class="n">lower</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="p">[</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">],</span> <span class="n">simple_mode</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">输出:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>primfn(A_1: handle, B_1: handle, C_1: handle) -&gt; ()
  attr = {&quot;from_legacy_te_schedule&quot;: True, &quot;global_symbol&quot;: &quot;main&quot;, &quot;tir.noalias&quot;: True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [1024, 1024], []),
             A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], []),
             B: Buffer(B_2: Pointer(float32), float32, [1024, 1024], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (m.outer: int32, 0, 32) {
    for (n.outer: int32, 0, 32) {
      for (m.inner.init: int32, 0, 32) {
        for (n.inner.init: int32, 0, 32) {
          C_2[((((m.outer*32768) + (m.inner.init*1024)) + (n.outer*32)) + n.inner.init)] = 0f32
        }
      }
      for (k.outer: int32, 0, 256) {
        for (k.inner: int32, 0, 4) {
          for (m.inner: int32, 0, 32) {
            for (n.inner: int32, 0, 32) {
              C_2[((((m.outer*32768) + (m.inner*1024)) + (n.outer*32)) + n.inner)] = ((float32*)C_2[((((m.outer*32768) + (m.inner*1024)) + (n.outer*32)) + n.inner)] + ((float32*)A_2[((((m.outer*32768) + (m.inner*1024)) + (k.outer*4)) + k.inner)]*(float32*)B_2[((((k.outer*4096) + (k.inner*1024)) + (n.outer*32)) + n.inner)]))
            }
          }
        }
      }
    }
  }
}
</pre></div>
</div>
</div>
<div class="section" id="vectorization">
<h2>矢量化<a class="headerlink" href="#vectorization" title="永久链接至标题">¶</a></h2>
<p>另一个重要的技巧是矢量化。当内存访问模式一致时，编译器可以检测到这种模式，并将连续的内存传递给向量处理器。在 TVM 中，我们可以使用 <cite>vectorize</cite> 接口向编译器提示这种模式，从而大幅提升计算速度。</p>
<p>在本教程中，我们选择对内循环的行数据进行矢量化，因为这种访问方式对缓存友好。</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">s</span> <span class="o">=</span> <span class="n">te</span><span class="o">.</span><span class="n">create_schedule</span><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="p">)</span>
<span class="n">mo</span><span class="p">,</span> <span class="n">no</span><span class="p">,</span> <span class="n">mi</span><span class="p">,</span> <span class="n">ni</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">bn</span><span class="p">,</span> <span class="n">bn</span><span class="p">)</span>
<span class="p">(</span><span class="n">kaxis</span><span class="p">,)</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">reduce_axis</span>
<span class="n">ko</span><span class="p">,</span> <span class="n">ki</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">kaxis</span><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><span class="n">kfactor</span><span class="p">)</span>

<span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><span class="n">mo</span><span class="p">,</span> <span class="n">no</span><span class="p">,</span> <span class="n">ko</span><span class="p">,</span> <span class="n">ki</span><span class="p">,</span> <span class="n">mi</span><span class="p">,</span> <span class="n">ni</span><span class="p">)</span>

<span class="c1"># Vectorization</span>
<span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><span class="n">ni</span><span class="p">)</span>

<span class="n">func</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="p">[</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">],</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;mmult&quot;</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">func</span>

<span class="n">c</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">numpy</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span> <span class="n">dev</span><span class="p">)</span>
<span class="n">func</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
<span class="n">tvm</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">answer</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">)</span>

<span class="n">evaluator</span> <span class="o">=</span> <span class="n">func</span><span class="o">.</span><span class="n">time_evaluator</span><span class="p">(</span><span class="n">func</span><span class="o">.</span><span class="n">entry_name</span><span class="p">,</span> <span class="n">dev</span><span class="p">,</span> <span class="n">number</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Opt2: </span><span class="si">%f</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">evaluator</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">输出:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>Opt2: 0.240382
</pre></div>
</div>
<p>这是矢量化之后产生的 IR。</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nb">print</span><span class="p">(</span><span class="n">tvm</span><span class="o">.</span><span class="n">lower</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="p">[</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">],</span> <span class="n">simple_mode</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">输出:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>primfn(A_1: handle, B_1: handle, C_1: handle) -&gt; ()
  attr = {&quot;from_legacy_te_schedule&quot;: True, &quot;global_symbol&quot;: &quot;main&quot;, &quot;tir.noalias&quot;: True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], []),
             C: Buffer(C_2: Pointer(float32), float32, [1024, 1024], []),
             B: Buffer(B_2: Pointer(float32), float32, [1024, 1024], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (m.outer: int32, 0, 32) {
    for (n.outer: int32, 0, 32) {
      for (m.inner.init: int32, 0, 32) {
        C_2[ramp((((m.outer*32768) + (m.inner.init*1024)) + (n.outer*32)), 1, 32)] = broadcast(0f32, 32)
      }
      for (k.outer: int32, 0, 256) {
        for (k.inner: int32, 0, 4) {
          for (m.inner: int32, 0, 32) {
            C_2[ramp((((m.outer*32768) + (m.inner*1024)) + (n.outer*32)), 1, 32)] = ((float32x32*)C_2[ramp((((m.outer*32768) + (m.inner*1024)) + (n.outer*32)), 1, 32)] + (broadcast((float32*)A_2[((((m.outer*32768) + (m.inner*1024)) + (k.outer*4)) + k.inner)], 32)*(float32x32*)B_2[ramp((((k.outer*4096) + (k.inner*1024)) + (n.outer*32)), 1, 32)]))
          }
        }
      }
    }
  }
}
</pre></div>
</div>
</div>
<div class="section" id="loop-permutation">
<h2>循环重排<a class="headerlink" href="#loop-permutation" title="永久链接至标题">¶</a></h2>
<p>查看上面的 IR，可以看到 B 和 C 的内循环行数据都已进行了矢量化。接下来我们来看 A 的访问模式。在当前的 schedule 中，A 是逐列访问的，这对缓存不友好。如果我们交换 ki 与内轴 mi 的嵌套循环顺序，A 矩阵的访问模式会对缓存更加友好。</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">s</span> <span class="o">=</span> <span class="n">te</span><span class="o">.</span><span class="n">create_schedule</span><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="p">)</span>
<span class="n">mo</span><span class="p">,</span> <span class="n">no</span><span class="p">,</span> <span class="n">mi</span><span class="p">,</span> <span class="n">ni</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">bn</span><span class="p">,</span> <span class="n">bn</span><span class="p">)</span>
<span class="p">(</span><span class="n">kaxis</span><span class="p">,)</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">reduce_axis</span>
<span class="n">ko</span><span class="p">,</span> <span class="n">ki</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">kaxis</span><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><span class="n">kfactor</span><span class="p">)</span>

<span class="c1"># re-ordering</span>
<span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><span class="n">mo</span><span class="p">,</span> <span class="n">no</span><span class="p">,</span> <span class="n">ko</span><span class="p">,</span> <span class="n">mi</span><span class="p">,</span> <span class="n">ki</span><span class="p">,</span> <span class="n">ni</span><span class="p">)</span>
<span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><span class="n">ni</span><span class="p">)</span>

<span class="n">func</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="p">[</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">],</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;mmult&quot;</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">func</span>

<span class="n">c</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">numpy</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span> <span class="n">dev</span><span class="p">)</span>
<span class="n">func</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
<span class="n">tvm</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">answer</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">)</span>

<span class="n">evaluator</span> <span class="o">=</span> <span class="n">func</span><span class="o">.</span><span class="n">time_evaluator</span><span class="p">(</span><span class="n">func</span><span class="o">.</span><span class="n">entry_name</span><span class="p">,</span> <span class="n">dev</span><span class="p">,</span> <span class="n">number</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Opt3: </span><span class="si">%f</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">evaluator</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">输出:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>Opt3: 0.095810
</pre></div>
</div>
<p>这是循环重排之后产生的 IR。</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nb">print</span><span class="p">(</span><span class="n">tvm</span><span class="o">.</span><span class="n">lower</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="p">[</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">],</span> <span class="n">simple_mode</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">输出:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>primfn(A_1: handle, B_1: handle, C_1: handle) -&gt; ()
  attr = {&quot;from_legacy_te_schedule&quot;: True, &quot;global_symbol&quot;: &quot;main&quot;, &quot;tir.noalias&quot;: True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [1024, 1024], []),
             A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], []),
             B: Buffer(B_2: Pointer(float32), float32, [1024, 1024], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (m.outer: int32, 0, 32) {
    for (n.outer: int32, 0, 32) {
      for (m.inner.init: int32, 0, 32) {
        C_2[ramp((((m.outer*32768) + (m.inner.init*1024)) + (n.outer*32)), 1, 32)] = broadcast(0f32, 32)
      }
      for (k.outer: int32, 0, 256) {
        for (m.inner: int32, 0, 32) {
          for (k.inner: int32, 0, 4) {
            C_2[ramp((((m.outer*32768) + (m.inner*1024)) + (n.outer*32)), 1, 32)] = ((float32x32*)C_2[ramp((((m.outer*32768) + (m.inner*1024)) + (n.outer*32)), 1, 32)] + (broadcast((float32*)A_2[((((m.outer*32768) + (m.inner*1024)) + (k.outer*4)) + k.inner)], 32)*(float32x32*)B_2[ramp((((k.outer*4096) + (k.inner*1024)) + (n.outer*32)), 1, 32)]))
          }
        }
      }
    }
  }
}
</pre></div>
</div>
</div>
<div class="section" id="array-packing">
<h2>数组打包<a class="headerlink" href="#array-packing" title="永久链接至标题">¶</a></h2>
<p>另一个重要的技巧是数组打包。该技巧对多维数组的存储顺序进行重排，使其在展平并存储到一维内存后能够被顺序访问。</p>
<img alt="https://github.com/dmlc/web-data/raw/main/tvm/tutorial/array-packing.png" class="align-center" src="https://github.com/dmlc/web-data/raw/main/tvm/tutorial/array-packing.png" />
<p>注意：此图是数组打包工作原理的一般性说明。</p>
<p>我们可以使用数组打包来改善 B 的访问模式。观察 B 展平后的数组访问模式可以发现，当我们在 K 维度上迭代时，访问是不连续的。我们可以将维度为 [K][N] 的 B 重排为维度 [N/bn][K][bn]，其中 bn 是分块因子，也是内循环中 B 的向量大小。这种重排将 N 拆分为两个维度 — bigN (N/bn) 和 littleN (bn) — 并且新维度 [N/bn][K][bn] 与 B 从外循环到内循环的索引 (no, ko, ki, ni) 相匹配，因此 B 展平后的内存访问是连续的。</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># We have to re-write the algorithm slightly.</span>
<span class="n">packedB</span> <span class="o">=</span> <span class="n">te</span><span class="o">.</span><span class="n">compute</span><span class="p">(</span>
    <span class="p">(</span><span class="n">N</span> <span class="o">/</span> <span class="n">bn</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">bn</span><span class="p">),</span> <span class="k">lambda</span> <span class="n">bigN</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">littleN</span><span class="p">:</span> <span class="n">B</span><span class="p">[</span><span class="n">k</span><span class="p">,</span> <span class="n">bigN</span> <span class="o">*</span> <span class="n">bn</span> <span class="o">+</span> <span class="n">littleN</span><span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;packedB&quot;</span>
<span class="p">)</span>
<span class="n">C</span> <span class="o">=</span> <span class="n">te</span><span class="o">.</span><span class="n">compute</span><span class="p">(</span>
    <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span>
    <span class="k">lambda</span> <span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">:</span> <span class="n">te</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">A</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">k</span><span class="p">]</span> <span class="o">*</span> <span class="n">packedB</span><span class="p">[</span><span class="n">n</span> <span class="o">//</span> <span class="n">bn</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">tvm</span><span class="o">.</span><span class="n">tir</span><span class="o">.</span><span class="n">indexmod</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">bn</span><span class="p">)],</span> <span class="n">axis</span><span class="o">=</span><span class="n">k</span><span class="p">),</span>
    <span class="n">name</span><span class="o">=</span><span class="s2">&quot;C&quot;</span><span class="p">,</span>
<span class="p">)</span>

<span class="n">s</span> <span class="o">=</span> <span class="n">te</span><span class="o">.</span><span class="n">create_schedule</span><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="p">)</span>

<span class="n">mo</span><span class="p">,</span> <span class="n">no</span><span class="p">,</span> <span class="n">mi</span><span class="p">,</span> <span class="n">ni</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">bn</span><span class="p">,</span> <span class="n">bn</span><span class="p">)</span>
<span class="p">(</span><span class="n">kaxis</span><span class="p">,)</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">reduce_axis</span>
<span class="n">ko</span><span class="p">,</span> <span class="n">ki</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">kaxis</span><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><span class="n">kfactor</span><span class="p">)</span>

<span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><span class="n">mo</span><span class="p">,</span> <span class="n">no</span><span class="p">,</span> <span class="n">ko</span><span class="p">,</span> <span class="n">mi</span><span class="p">,</span> <span class="n">ki</span><span class="p">,</span> <span class="n">ni</span><span class="p">)</span>
<span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><span class="n">ni</span><span class="p">)</span>

<span class="n">bigN</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">littleN</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">packedB</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span>
<span class="n">s</span><span class="p">[</span><span class="n">packedB</span><span class="p">]</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><span class="n">littleN</span><span class="p">)</span>
<span class="n">s</span><span class="p">[</span><span class="n">packedB</span><span class="p">]</span><span class="o">.</span><span class="n">parallel</span><span class="p">(</span><span class="n">bigN</span><span class="p">)</span>

<span class="n">func</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="p">[</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">],</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;mmult&quot;</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">func</span>

<span class="n">c</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">numpy</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span> <span class="n">dev</span><span class="p">)</span>
<span class="n">func</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
<span class="n">tvm</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">answer</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">)</span>

<span class="n">evaluator</span> <span class="o">=</span> <span class="n">func</span><span class="o">.</span><span class="n">time_evaluator</span><span class="p">(</span><span class="n">func</span><span class="o">.</span><span class="n">entry_name</span><span class="p">,</span> <span class="n">dev</span><span class="p">,</span> <span class="n">number</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Opt4: </span><span class="si">%f</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">evaluator</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">输出:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>Opt4: 0.088119
</pre></div>
</div>
<p>这是执行了数组打包之后产生的 IR。</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nb">print</span><span class="p">(</span><span class="n">tvm</span><span class="o">.</span><span class="n">lower</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="p">[</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">],</span> <span class="n">simple_mode</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">输出:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>primfn(A_1: handle, B_1: handle, C_1: handle) -&gt; ()
  attr = {&quot;from_legacy_te_schedule&quot;: True, &quot;global_symbol&quot;: &quot;main&quot;, &quot;tir.noalias&quot;: True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [1024, 1024], []),
             A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], []),
             B: Buffer(B_2: Pointer(float32), float32, [1024, 1024], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  allocate(packedB: Pointer(global float32x32), float32x32, [32768]), storage_scope = global {
    for (bigN: int32, 0, 32) &quot;parallel&quot; {
      for (k: int32, 0, 1024) {
        packedB[ramp(((bigN*32768) + (k*32)), 1, 32)] = (float32x32*)B_2[ramp(((k*1024) + (bigN*32)), 1, 32)]
      }
    }
    for (m.outer: int32, 0, 32) {
      for (n.outer: int32, 0, 32) {
        for (m.inner.init: int32, 0, 32) {
          C_2[ramp((((m.outer*32768) + (m.inner.init*1024)) + (n.outer*32)), 1, 32)] = broadcast(0f32, 32)
        }
        for (k.outer: int32, 0, 256) {
          for (m.inner: int32, 0, 32) {
            for (k.inner: int32, 0, 4) {
              C_2[ramp((((m.outer*32768) + (m.inner*1024)) + (n.outer*32)), 1, 32)] = ((float32x32*)C_2[ramp((((m.outer*32768) + (m.inner*1024)) + (n.outer*32)), 1, 32)] + (broadcast((float32*)A_2[((((m.outer*32768) + (m.inner*1024)) + (k.outer*4)) + k.inner)], 32)*(float32x32*)packedB[ramp((((n.outer*32768) + (k.outer*128)) + (k.inner*32)), 1, 32)]))
            }
          }
        }
      }
    }
  }
}
</pre></div>
</div>
</div>
<div class="section" id="write-cache-for-blocks">
<h2>块的写缓存<a class="headerlink" href="#write-cache-for-blocks" title="永久链接至标题">¶</a></h2>
<p>分块后，程序将结果逐块写入 C，其访问模式不是顺序的。因此，我们可以使用一个顺序的缓存数组来保存块的结果，并在所有块的结果准备好后再写入 C。</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">s</span> <span class="o">=</span> <span class="n">te</span><span class="o">.</span><span class="n">create_schedule</span><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="p">)</span>

<span class="c1"># Allocate write cache</span>
<span class="n">CC</span> <span class="o">=</span> <span class="n">s</span><span class="o">.</span><span class="n">cache_write</span><span class="p">(</span><span class="n">C</span><span class="p">,</span> <span class="s2">&quot;global&quot;</span><span class="p">)</span>

<span class="n">mo</span><span class="p">,</span> <span class="n">no</span><span class="p">,</span> <span class="n">mi</span><span class="p">,</span> <span class="n">ni</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">bn</span><span class="p">,</span> <span class="n">bn</span><span class="p">)</span>

<span class="c1"># Write cache is computed at no</span>
<span class="n">s</span><span class="p">[</span><span class="n">CC</span><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">],</span> <span class="n">no</span><span class="p">)</span>

<span class="c1"># New inner axes</span>
<span class="n">mc</span><span class="p">,</span> <span class="n">nc</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">CC</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span>

<span class="p">(</span><span class="n">kaxis</span><span class="p">,)</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">CC</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">reduce_axis</span>
<span class="n">ko</span><span class="p">,</span> <span class="n">ki</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">CC</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">kaxis</span><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><span class="n">kfactor</span><span class="p">)</span>
<span class="n">s</span><span class="p">[</span><span class="n">CC</span><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><span class="n">ko</span><span class="p">,</span> <span class="n">mc</span><span class="p">,</span> <span class="n">ki</span><span class="p">,</span> <span class="n">nc</span><span class="p">)</span>
<span class="n">s</span><span class="p">[</span><span class="n">CC</span><span class="p">]</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><span class="n">nc</span><span class="p">)</span>

<span class="c1"># TODO: Add separate optimization step to discuss loop unrolling</span>
<span class="c1"># unrolling is a loop optimization strategy which can reduce branch</span>
<span class="c1"># prediction failures and increase the chance of concurrent execution</span>
<span class="c1"># unroll kfactor loops</span>
<span class="n">s</span><span class="p">[</span><span class="n">CC</span><span class="p">]</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">ki</span><span class="p">)</span>

<span class="n">bigN</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">littleN</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">packedB</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span>
<span class="n">s</span><span class="p">[</span><span class="n">packedB</span><span class="p">]</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><span class="n">littleN</span><span class="p">)</span>
<span class="n">s</span><span class="p">[</span><span class="n">packedB</span><span class="p">]</span><span class="o">.</span><span class="n">parallel</span><span class="p">(</span><span class="n">bigN</span><span class="p">)</span>

<span class="n">func</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="p">[</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">],</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;mmult&quot;</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">func</span>

<span class="n">c</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">numpy</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span> <span class="n">dev</span><span class="p">)</span>
<span class="n">func</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
<span class="n">tvm</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">answer</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">)</span>

<span class="n">evaluator</span> <span class="o">=</span> <span class="n">func</span><span class="o">.</span><span class="n">time_evaluator</span><span class="p">(</span><span class="n">func</span><span class="o">.</span><span class="n">entry_name</span><span class="p">,</span> <span class="n">dev</span><span class="p">,</span> <span class="n">number</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Opt5: </span><span class="si">%f</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">evaluator</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">输出:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>Opt5: 0.087136
</pre></div>
</div>
<p>这是添加写缓存之后产生的 IR。</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nb">print</span><span class="p">(</span><span class="n">tvm</span><span class="o">.</span><span class="n">lower</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="p">[</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">],</span> <span class="n">simple_mode</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">输出:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>primfn(A_1: handle, B_1: handle, C_1: handle) -&gt; ()
  attr = {&quot;from_legacy_te_schedule&quot;: True, &quot;global_symbol&quot;: &quot;main&quot;, &quot;tir.noalias&quot;: True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [1024, 1024], []),
             A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], []),
             B: Buffer(B_2: Pointer(float32), float32, [1024, 1024], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  allocate(packedB: Pointer(global float32x32), float32x32, [32768]), storage_scope = global;
  allocate(C.global: Pointer(global float32), float32, [1024]), storage_scope = global {
    for (bigN: int32, 0, 32) &quot;parallel&quot; {
      for (k: int32, 0, 1024) {
        packedB[ramp(((bigN*32768) + (k*32)), 1, 32)] = (float32x32*)B_2[ramp(((k*1024) + (bigN*32)), 1, 32)]
      }
    }
    for (m.outer: int32, 0, 32) {
      for (n.outer: int32, 0, 32) {
        for (m.c.init: int32, 0, 32) {
          C.global[ramp((m.c.init*32), 1, 32)] = broadcast(0f32, 32)
        }
        for (k.outer: int32, 0, 256) {
          for (m.c: int32, 0, 32) {
            C.global[ramp((m.c*32), 1, 32)] = ((float32x32*)C.global[ramp((m.c*32), 1, 32)] + (broadcast((float32*)A_2[(((m.outer*32768) + (m.c*1024)) + (k.outer*4))], 32)*(float32x32*)packedB[ramp(((n.outer*32768) + (k.outer*128)), 1, 32)]))
            C.global[ramp((m.c*32), 1, 32)] = ((float32x32*)C.global[ramp((m.c*32), 1, 32)] + (broadcast((float32*)A_2[((((m.outer*32768) + (m.c*1024)) + (k.outer*4)) + 1)], 32)*(float32x32*)packedB[ramp((((n.outer*32768) + (k.outer*128)) + 32), 1, 32)]))
            C.global[ramp((m.c*32), 1, 32)] = ((float32x32*)C.global[ramp((m.c*32), 1, 32)] + (broadcast((float32*)A_2[((((m.outer*32768) + (m.c*1024)) + (k.outer*4)) + 2)], 32)*(float32x32*)packedB[ramp((((n.outer*32768) + (k.outer*128)) + 64), 1, 32)]))
            C.global[ramp((m.c*32), 1, 32)] = ((float32x32*)C.global[ramp((m.c*32), 1, 32)] + (broadcast((float32*)A_2[((((m.outer*32768) + (m.c*1024)) + (k.outer*4)) + 3)], 32)*(float32x32*)packedB[ramp((((n.outer*32768) + (k.outer*128)) + 96), 1, 32)]))
          }
        }
        for (m.inner: int32, 0, 32) {
          for (n.inner: int32, 0, 32) {
            C_2[((((m.outer*32768) + (m.inner*1024)) + (n.outer*32)) + n.inner)] = (float32*)C.global[((m.inner*32) + n.inner)]
          }
        }
      }
    }
  }
}
</pre></div>
</div>
</div>
<div class="section" id="parallel">
<h2>并行<a class="headerlink" href="#parallel" title="永久链接至标题">¶</a></h2>
<p>此外，我们还可以利用多核处理器进行线程级并行化。</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">s</span> <span class="o">=</span> <span class="n">te</span><span class="o">.</span><span class="n">create_schedule</span><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="p">)</span>

<span class="n">CC</span> <span class="o">=</span> <span class="n">s</span><span class="o">.</span><span class="n">cache_write</span><span class="p">(</span><span class="n">C</span><span class="p">,</span> <span class="s2">&quot;global&quot;</span><span class="p">)</span>

<span class="n">mo</span><span class="p">,</span> <span class="n">no</span><span class="p">,</span> <span class="n">mi</span><span class="p">,</span> <span class="n">ni</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">C</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">bn</span><span class="p">,</span> <span class="n">bn</span><span class="p">)</span>

<span class="n">s</span><span class="p">[</span><span class="n">CC</span><span class="p">]</span><span class="o">.</span><span class="n">compute_at</span><span class="p">(</span><span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">],</span> <span class="n">no</span><span class="p">)</span>

<span class="n">mc</span><span class="p">,</span> <span class="n">nc</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">CC</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span>

<span class="p">(</span><span class="n">kaxis</span><span class="p">,)</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">CC</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">reduce_axis</span>
<span class="n">ko</span><span class="p">,</span> <span class="n">ki</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">CC</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">kaxis</span><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><span class="n">kfactor</span><span class="p">)</span>
<span class="n">s</span><span class="p">[</span><span class="n">CC</span><span class="p">]</span><span class="o">.</span><span class="n">reorder</span><span class="p">(</span><span class="n">ko</span><span class="p">,</span> <span class="n">mc</span><span class="p">,</span> <span class="n">ki</span><span class="p">,</span> <span class="n">nc</span><span class="p">)</span>
<span class="n">s</span><span class="p">[</span><span class="n">CC</span><span class="p">]</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><span class="n">nc</span><span class="p">)</span>
<span class="n">s</span><span class="p">[</span><span class="n">CC</span><span class="p">]</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">ki</span><span class="p">)</span>

<span class="c1"># parallel</span>
<span class="n">s</span><span class="p">[</span><span class="n">C</span><span class="p">]</span><span class="o">.</span><span class="n">parallel</span><span class="p">(</span><span class="n">mo</span><span class="p">)</span>

<span class="n">bigN</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">littleN</span> <span class="o">=</span> <span class="n">s</span><span class="p">[</span><span class="n">packedB</span><span class="p">]</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">axis</span>
<span class="n">s</span><span class="p">[</span><span class="n">packedB</span><span class="p">]</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><span class="n">littleN</span><span class="p">)</span>
<span class="n">s</span><span class="p">[</span><span class="n">packedB</span><span class="p">]</span><span class="o">.</span><span class="n">parallel</span><span class="p">(</span><span class="n">bigN</span><span class="p">)</span>

<span class="n">func</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="p">[</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">],</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;mmult&quot;</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">func</span>

<span class="n">c</span> <span class="o">=</span> <span class="n">tvm</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">numpy</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span> <span class="n">dev</span><span class="p">)</span>
<span class="n">func</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
<span class="n">tvm</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">answer</span><span class="p">,</span> <span class="n">rtol</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">)</span>

<span class="n">evaluator</span> <span class="o">=</span> <span class="n">func</span><span class="o">.</span><span class="n">time_evaluator</span><span class="p">(</span><span class="n">func</span><span class="o">.</span><span class="n">entry_name</span><span class="p">,</span> <span class="n">dev</span><span class="p">,</span> <span class="n">number</span><span class="o">=</span><span class="mi">50</span><span class="p">)</span>
<span class="n">opt6_time</span> <span class="o">=</span> <span class="n">evaluator</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Opt6: </span><span class="si">%f</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">opt6_time</span><span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">输出:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>Opt6: 0.108619
</pre></div>
</div>
<p>这是并行化之后产生的 IR。</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nb">print</span><span class="p">(</span><span class="n">tvm</span><span class="o">.</span><span class="n">lower</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="p">[</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">],</span> <span class="n">simple_mode</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">输出:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>primfn(A_1: handle, B_1: handle, C_1: handle) -&gt; ()
  attr = {&quot;from_legacy_te_schedule&quot;: True, &quot;global_symbol&quot;: &quot;main&quot;, &quot;tir.noalias&quot;: True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [1024, 1024], []),
             A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], []),
             B: Buffer(B_2: Pointer(float32), float32, [1024, 1024], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  allocate(packedB: Pointer(global float32x32), float32x32, [32768]), storage_scope = global {
    for (bigN: int32, 0, 32) &quot;parallel&quot; {
      for (k: int32, 0, 1024) {
        packedB[ramp(((bigN*32768) + (k*32)), 1, 32)] = (float32x32*)B_2[ramp(((k*1024) + (bigN*32)), 1, 32)]
      }
    }
    for (m.outer: int32, 0, 32) &quot;parallel&quot; {
      allocate(C.global: Pointer(global float32), float32, [1024]), storage_scope = global;
      for (n.outer: int32, 0, 32) {
        for (m.c.init: int32, 0, 32) {
          C.global[ramp((m.c.init*32), 1, 32)] = broadcast(0f32, 32)
        }
        for (k.outer: int32, 0, 256) {
          for (m.c: int32, 0, 32) {
            C.global[ramp((m.c*32), 1, 32)] = ((float32x32*)C.global[ramp((m.c*32), 1, 32)] + (broadcast((float32*)A_2[(((m.outer*32768) + (m.c*1024)) + (k.outer*4))], 32)*(float32x32*)packedB[ramp(((n.outer*32768) + (k.outer*128)), 1, 32)]))
            C.global[ramp((m.c*32), 1, 32)] = ((float32x32*)C.global[ramp((m.c*32), 1, 32)] + (broadcast((float32*)A_2[((((m.outer*32768) + (m.c*1024)) + (k.outer*4)) + 1)], 32)*(float32x32*)packedB[ramp((((n.outer*32768) + (k.outer*128)) + 32), 1, 32)]))
            C.global[ramp((m.c*32), 1, 32)] = ((float32x32*)C.global[ramp((m.c*32), 1, 32)] + (broadcast((float32*)A_2[((((m.outer*32768) + (m.c*1024)) + (k.outer*4)) + 2)], 32)*(float32x32*)packedB[ramp((((n.outer*32768) + (k.outer*128)) + 64), 1, 32)]))
            C.global[ramp((m.c*32), 1, 32)] = ((float32x32*)C.global[ramp((m.c*32), 1, 32)] + (broadcast((float32*)A_2[((((m.outer*32768) + (m.c*1024)) + (k.outer*4)) + 3)], 32)*(float32x32*)packedB[ramp((((n.outer*32768) + (k.outer*128)) + 96), 1, 32)]))
          }
        }
        for (m.inner: int32, 0, 32) {
          for (n.inner: int32, 0, 32) {
            C_2[((((m.outer*32768) + (m.inner*1024)) + (n.outer*32)) + n.inner)] = (float32*)C.global[((m.inner*32) + n.inner)]
          }
        }
      }
    }
  }
}
</pre></div>
</div>
</div>
<div class="section" id="summary">
<h2>总结<a class="headerlink" href="#summary" title="永久链接至标题">¶</a></h2>
<p>仅用 18 行代码应用上述优化技巧后，我们生成的代码的性能可以达到使用 MKL 的 numpy 的 60%。请注意，网页上展示的输出是在一个非独占的 Docker 容器中测得的运行时间，因此它们是<em>不可靠的</em>。强烈建议您自行运行本教程，以观察 TVM 所带来的性能提升。</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-how-to-optimize-operators-opt-gemm-py">
<div class="sphx-glr-download docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/96137df89d8034b548f407123ec50ce9/opt_gemm.py"><code class="xref download docutils literal notranslate"><span class="pre">下载</span> <span class="pre">Python</span> <span class="pre">源代码:</span> <span class="pre">opt_gemm.py</span></code></a></p>
</div>
<div class="sphx-glr-download docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/0f8d36b3ffd04a5a08089dc671eb788e/opt_gemm.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">下载</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">opt_gemm.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>


           </div>
           
          </div>
          

<footer>

    <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
      
        <a href="opt_conv_cuda.html" class="btn btn-neutral float-right" title="如何在GPU上优化卷积" accesskey="n" rel="next">下一个 <span class="fa fa-arrow-circle-right"></span></a>
      
      
        <a href="index.html" class="btn btn-neutral float-left" title="优化张量算子" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> 上一个</a>
      
    </div>

<div id="button" class="backtop"><img src="../../_static/img/right.svg" alt="返回顶部"/> </div>
<section class="footerSec">
    <div class="footerHeader">
      <ul class="d-flex align-md-items-center justify-content-between flex-column flex-md-row">
        <li class="copywrite d-flex align-items-center">
          <h5 id="copy-right-info">© 2020 Apache Software Foundation | All rights reserved</h5>
        </li>
      </ul>

    </div>

    <ul>
      <li class="footernote">Copyright © 2020 The Apache Software Foundation. Apache TVM, Apache, the Apache feather, and the Apache TVM project logo are either trademarks or registered trademarks of the Apache Software Foundation.</li>
    </ul>

</section>
</footer>
        </div>
      </div>

    </section>

  </div>
  

    <script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.12.9/umd/popper.min.js" integrity="sha384-ApNbgh9B+Y1QKtv3Rn7W3mgPxhU9K/ScQsAP7hUibX39j7fakFPskvXusvfa0b4Q" crossorigin="anonymous"></script>
    <script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/js/bootstrap.min.js" integrity="sha384-JZR6Spejh4U02d8jOt6vLEHfe/JQGiRRSQQxSfFWpi1MquVdAyjUar5+76PVCmYl" crossorigin="anonymous"></script>

  <script type="text/javascript">
      // Once the DOM is ready, switch on the theme's sidebar navigation behaviour.
      // NOTE(review): the `true` argument is presumably the sticky-navigation flag of
      // SphinxRtdTheme.Navigation.enable; confirm against _static/js/theme.js.
      // The function-expression name below is visible only inside the function
      // itself, so no new global binding is introduced.
      jQuery(function enableThemeNavigation() {
          SphinxRtdTheme.Navigation.enable(true);
      });
  </script>

  
  
    
    <!-- Theme Analytics -->
    <script>
    // Google Analytics (analytics.js) asynchronous loader snippet.
    // Call arguments: i=window, s=document, o='script', g=loader URL, r='ga'.
    // It stores the global command name in window.GoogleAnalyticsObject and
    // installs a stub window.ga (only if one does not already exist) that
    // queues every call's arguments in ga.q. It stamps ga.l with the current
    // time in ms, then inserts an async <script> for the loader URL before the
    // first <script> element in the document.
    (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
      (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
      m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
    })(window,document,'script','https://www.google-analytics.com/analytics.js','ga');

    // Queue two commands: create a tracker for property UA-75982049-2 ('auto'
    // presumably lets GA pick the cookie domain), then send a pageview hit.
    // NOTE(review): this is a legacy Universal Analytics (UA-*) property loaded via
    // analytics.js. Google announced that UA stopped processing new hits in 2023.
    // Confirm whether this tracker is still wanted or should move to GA4 (gtag.js).
    ga('create', 'UA-75982049-2', 'auto');
    ga('send', 'pageview');
    </script>

    
   

</body>
</html>